{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "view-in-github"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_7_neuromorphic_datasets.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "47d5313e-c29d-4581-a9c7-a45122337069",
      "metadata": {
        "id": "47d5313e-c29d-4581-a9c7-a45122337069"
      },
      "source": [
        "[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/snntorch_alpha_w.png?raw=true' width=\"300\">](https://github.com/jeshraghian/snntorch/) \n",
        "[<img src='https://github.com/neuromorphs/tonic/blob/develop/docs/_static/tonic-logo-white.png?raw=true' width=\"200\">](https://github.com/neuromorphs/tonic/)\n",
        "\n",
        "\n",
        "# Neuromorphic Datasets with Tonic + snnTorch\n",
        "## Tutorial 7\n",
        "### By Gregor Lenz (https://lenzgregor.com) and Jason K. Eshraghian (www.jasoneshraghian.com)\n",
        "\n",
        "<a href=\"https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/examples/tutorial_7_neuromorphic_datasets.ipynb\">\n",
        "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
        "</a>\n",
        "\n",
        "[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub-Mark-Light-120px-plus.png?raw=true' width=\"28\">](https://github.com/jeshraghian/snntorch/) [<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub_Logo_White.png?raw=true' width=\"80\">](https://github.com/jeshraghian/snntorch/)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "oll2NNFeG1NG",
      "metadata": {
        "id": "oll2NNFeG1NG"
      },
      "source": [
        "The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:\n",
        "\n",
        "> <cite> [Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. \"Training Spiking Neural Networks Using Lessons From Deep Learning\". arXiv preprint arXiv:2109.12894, September 2021.](https://arxiv.org/abs/2109.12894) </cite>"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "ClgsZMOfBVby",
      "metadata": {
        "id": "ClgsZMOfBVby"
      },
      "source": [
        "# Introduction\n",
        "In this tutorial, you will:\n",
        "* Learn how to load neuromorphic datasets using [Tonic](https://github.com/neuromorphs/tonic)\n",
        "* Make use of caching to speed up dataloading\n",
        "* Train a CSNN with the [Neuromorphic-MNIST](https://tonic.readthedocs.io/en/latest/datasets.html#n-mnist) Dataset\n",
        "\n",
        "If running in Google Colab:\n",
        "* You may connect to GPU by checking `Runtime` > `Change runtime type` > `Hardware accelerator: GPU`\n",
        "* Next, install the latest PyPi distribution of snnTorch and Tonic by clicking into the following cell and pressing `Shift+Enter`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "hDnIEHOKB8LD",
      "metadata": {
        "id": "hDnIEHOKB8LD"
      },
      "outputs": [],
      "source": [
        "!pip install tonic --quiet \n",
        "!pip install snntorch --quiet"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "e93694d9-0f0a-46a0-b17f-c04ac9b73a63",
      "metadata": {
        "id": "e93694d9-0f0a-46a0-b17f-c04ac9b73a63"
      },
      "source": [
        "# 1. Using Tonic to Load Neuromorphic Datasets\n",
        "Loading datasets from neuromorphic sensors is made super simple thanks to [Tonic](https://github.com/neuromorphs/tonic), which works much like PyTorch vision.\n",
        "\n",
        "Let's start by loading the neuromorphic version of the MNIST dataset, called [N-MNIST](https://tonic.readthedocs.io/en/latest/reference/datasets.html#n-mnist). We can have a look at some raw events to get a feel for what we're working with."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "7d286ef9-5fe6-4578-a686-91559a1f81d2",
      "metadata": {
        "id": "7d286ef9-5fe6-4578-a686-91559a1f81d2"
      },
      "outputs": [],
      "source": [
        "import tonic\n",
        "\n",
        "dataset = tonic.datasets.NMNIST(save_to='./data', train=True)\n",
        "events, target = dataset[0]\n",
        "print(events)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "IbnLy107Oo3a",
      "metadata": {
        "id": "IbnLy107Oo3a"
      },
      "source": [
        "Each row corresponds to a single event, which consists of four parameters: (*x-coordinate, y-coordinate, timestamp, polarity*).\n",
        "\n",
        "* x & y co-ordinates correspond to an address in a $34 \\times 34$ grid.\n",
        "\n",
        "* The timestamp of the event is recorded in microseconds.\n",
        "\n",
        "* The polarity refers to whether an on-spike (+1) or an off-spike (-1) occured; i.e., an increase in brightness or a decrease in brightness.\n",
        "\n",
        "If we were to accumulate those events over time and plot the bins as images, it looks like this:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e6d6a0b0-2a73-4dbe-9576-d06c251f0fa4",
      "metadata": {
        "id": "e6d6a0b0-2a73-4dbe-9576-d06c251f0fa4"
      },
      "outputs": [],
      "source": [
        "tonic.utils.plot_event_grid(events)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "f6bcc031-d11a-4471-b3aa-335eec76d7ad",
      "metadata": {
        "id": "f6bcc031-d11a-4471-b3aa-335eec76d7ad"
      },
      "source": [
        "## 1.1 Transformations\n",
        "\n",
        "However, neural nets don't take lists of events as input. The raw data must be converted into a suitable representation, such as a tensor. We can choose a set of transforms to apply to our data before feeding it to our network. The neuromorphic camera sensor has a temporal resolution of microseconds, which when converted into a dense representation, ends up as a very large tensor. That is why we bin events into a smaller number of frames using the [ToFrame transformation](https://tonic.readthedocs.io/en/latest/reference/transformations.html#frames), which reduces temporal precision but also allows us to work with it in a dense format.\n",
        "\n",
        "* `time_window=1000` integrates events into 1000$~\\mu$s bins\n",
        "\n",
        "* Denoise removes isolated, one-off events. If no event occurs within a neighbourhood of 1 pixel across `filter_time` microseconds, the event is filtered. Smaller `filter_time` will filter more events."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "30f249be-8a65-4c1c-a21c-d561e904b4bf",
      "metadata": {
        "id": "30f249be-8a65-4c1c-a21c-d561e904b4bf"
      },
      "outputs": [],
      "source": [
        "import tonic.transforms as transforms\n",
        "\n",
        "sensor_size = tonic.datasets.NMNIST.sensor_size\n",
        "\n",
        "# Denoise removes isolated, one-off events\n",
        "# time_window\n",
        "frame_transform = transforms.Compose([transforms.Denoise(filter_time=10000), \n",
        "                                      transforms.ToFrame(sensor_size=sensor_size, \n",
        "                                                         time_window=1000)\n",
        "                                     ])\n",
        "\n",
        "trainset = tonic.datasets.NMNIST(save_to='./data', transform=frame_transform, train=True)\n",
        "testset = tonic.datasets.NMNIST(save_to='./data', transform=frame_transform, train=False)"
      ]
    },
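    {
      "cell_type": "markdown",
      "id": "sketch-frame-binning",
      "metadata": {
        "id": "sketch-frame-binning"
      },
      "source": [
        "To make the binning step concrete, here is a minimal pure-Python sketch of what integrating events into fixed `time_window` bins means. It is an illustration only, not Tonic's implementation: the helper `bin_events` and the toy event list are hypothetical, and events are assumed to be `(x, y, t, p)` tuples with `t` in microseconds and `p` used as a channel index (0 or 1).\n",
        "\n",
        "```python\n",
        "def bin_events(events, sensor_size, time_window):\n",
        "    # Hypothetical helper: accumulate (x, y, t, p) events into dense frames.\n",
        "    # sensor_size is (width, height, polarities); time_window is in microseconds.\n",
        "    width, height, n_pol = sensor_size\n",
        "    t_start = min(t for _, _, t, _ in events)\n",
        "    t_end = max(t for _, _, t, _ in events)\n",
        "    n_frames = (t_end - t_start) // time_window + 1\n",
        "    # frames[frame][polarity][y][x] holds an event count\n",
        "    frames = [[[[0] * width for _ in range(height)] for _ in range(n_pol)]\n",
        "              for _ in range(n_frames)]\n",
        "    for x, y, t, p in events:\n",
        "        frames[(t - t_start) // time_window][p][y][x] += 1\n",
        "    return frames\n",
        "\n",
        "# Toy recording on a 4x4 sensor with two polarity channels\n",
        "toy_events = [(0, 0, 100, 1), (1, 0, 900, 1), (2, 3, 1500, 0), (0, 0, 2100, 1)]\n",
        "frames = bin_events(toy_events, sensor_size=(4, 4, 2), time_window=1000)\n",
        "print(len(frames))  # 3 time bins\n",
        "```\n",
        "\n",
        "A larger `time_window` gives fewer frames with more events in each, trading temporal precision for a smaller tensor."
      ]
    },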
    {
      "cell_type": "markdown",
      "id": "70be0b77-9405-44f6-af49-dfc4d49566c6",
      "metadata": {
        "id": "70be0b77-9405-44f6-af49-dfc4d49566c6"
      },
      "source": [
        "## 1.2 Fast Dataloading via Caching\n",
        "\n",
        "The original data is stored in a format that is slow to read. To speed up dataloading, we can make use of disk caching. That means that once files are loaded from the original file, they are written to disk in an efficient format in our cache directory. Let's compare some file reading speeds to read 100 examples."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "3a6bf2a2-ff9f-4cdc-8cb3-02a1a9d71a11",
      "metadata": {
        "id": "3a6bf2a2-ff9f-4cdc-8cb3-02a1a9d71a11"
      },
      "outputs": [],
      "source": [
        "def load_sample_simple():\n",
        "    for i in range(100):\n",
        "        events, target = trainset[i]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1a9d3b28-b303-4a17-be78-b9918911a7cd",
      "metadata": {
        "id": "1a9d3b28-b303-4a17-be78-b9918911a7cd"
      },
      "outputs": [],
      "source": [
        "%timeit -o load_sample_simple()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "b957b32f-d76b-42c0-8c1c-b6b63e84e2ef",
      "metadata": {
        "id": "b957b32f-d76b-42c0-8c1c-b6b63e84e2ef"
      },
      "source": [
        "We can decrease the time it takes to read 100 samples by using a PyTorch DataLoader in addition to disk caching."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f47e0798-5259-491a-9d3a-59b13b1b0983",
      "metadata": {
        "id": "f47e0798-5259-491a-9d3a-59b13b1b0983"
      },
      "outputs": [],
      "source": [
        "from torch.utils.data import DataLoader\n",
        "from tonic import CachedDataset\n",
        "\n",
        "cached_trainset = CachedDataset(trainset, cache_path='./cache/nmnist/train')\n",
        "cached_dataloader = DataLoader(cached_trainset)\n",
        "\n",
        "def load_sample_cached():\n",
        "    for i, (events, target) in enumerate(iter(cached_dataloader)):\n",
        "        if i > 99: break"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "17a0219b-4d15-4f0b-b5be-8c728b5e24a9",
      "metadata": {
        "id": "17a0219b-4d15-4f0b-b5be-8c728b5e24a9"
      },
      "outputs": [],
      "source": [
        "%timeit -o -r 20 load_sample_cached()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "3831428d-0511-4fde-84d9-11d08fa45df7",
      "metadata": {
        "id": "3831428d-0511-4fde-84d9-11d08fa45df7"
      },
      "source": [
        "## 1.3 Even Faster DataLoading via Batching\n",
        "\n",
        "Now that we've reduced our loading time, we also want to use batching to make efficient use of the GPU. \n",
        "\n",
        "Because event recordings have different lengths, we are going to provide a  collation function `tonic.collation.PadTensors()` that will pad out shorter recordings to ensure all samples in a batch have the same dimensions. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5b35c7cd-d292-47cd-9203-7f31aa7f7207",
      "metadata": {
        "id": "5b35c7cd-d292-47cd-9203-7f31aa7f7207"
      },
      "outputs": [],
      "source": [
        "batch_size = 100\n",
        "trainloader = DataLoader(cached_trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors())"
      ]
    },
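    {
      "cell_type": "markdown",
      "id": "sketch-pad-batch",
      "metadata": {
        "id": "sketch-pad-batch"
      },
      "source": [
        "The padding idea can be sketched in plain Python. This is a simplified illustration, not the `PadTensors` implementation: `pad_batch` is a hypothetical helper, and each recording is reduced to a list of per-time-step values.\n",
        "\n",
        "```python\n",
        "def pad_batch(recordings, pad_value=0):\n",
        "    # Hypothetical helper: pad every recording to the longest length in the batch\n",
        "    max_len = max(len(rec) for rec in recordings)\n",
        "    return [rec + [pad_value] * (max_len - len(rec)) for rec in recordings]\n",
        "\n",
        "batch = pad_batch([[1, 2, 3], [4], [5, 6]])\n",
        "print(batch)  # [[1, 2, 3], [4, 0, 0], [5, 6, 0]]\n",
        "```\n",
        "\n",
        "After padding, every recording has the same number of time steps, so the batch can be stacked into a single tensor."
      ]
    },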
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "14b9af4f-141e-4301-8451-445957ec8707",
      "metadata": {
        "id": "14b9af4f-141e-4301-8451-445957ec8707"
      },
      "outputs": [],
      "source": [
        "def load_sample_batched():\n",
        "    events, target = next(iter(cached_dataloader))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "3dc4b27a-63ac-4edc-94e9-589d548c4769",
      "metadata": {
        "id": "3dc4b27a-63ac-4edc-94e9-589d548c4769"
      },
      "outputs": [],
      "source": [
        "%timeit -o -r 10 load_sample_batched()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a82a7afd-c011-4cd6-ba04-1e7cd438bc1f",
      "metadata": {
        "id": "a82a7afd-c011-4cd6-ba04-1e7cd438bc1f"
      },
      "source": [
        "By using disk caching and a PyTorch dataloader with multithreading and batching support, we have reduced loading times to less than a tenth per sample in comparison to naively iterating over the dataset!"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "2ded1bd9-e2f1-479e-899c-c6c2652e6fc9",
      "metadata": {
        "id": "2ded1bd9-e2f1-479e-899c-c6c2652e6fc9"
      },
      "source": [
        "# 2. Training our network using frames created from events"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "9be82d75-69ef-4c1b-ad85-4eca84c73ccf",
      "metadata": {
        "id": "9be82d75-69ef-4c1b-ad85-4eca84c73ccf"
      },
      "source": [
        "Now let's actually train a network on the N-MNIST classification task. We start by defining our caching wrappers and dataloaders. While doing that, we're also going to apply some augmentations to the training data. The samples we receive from the cached dataset are frames, so we can make use of PyTorch Vision to apply whatever random transform we would like."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ace6cd0b-7b56-4422-b3bd-23bac65db9bd",
      "metadata": {
        "id": "ace6cd0b-7b56-4422-b3bd-23bac65db9bd"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torchvision\n",
        "\n",
        "transform = tonic.transforms.Compose([torch.from_numpy,\n",
        "                                      torchvision.transforms.RandomRotation([-10,10])])\n",
        "\n",
        "cached_trainset = CachedDataset(trainset, transform=transform, cache_path='./cache/nmnist/train')\n",
        "\n",
        "# no augmentations for the testset\n",
        "cached_testset = CachedDataset(testset, cache_path='./cache/nmnist/test')\n",
        "\n",
        "batch_size = 128\n",
        "trainloader = DataLoader(cached_trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(), shuffle=True)\n",
        "testloader = DataLoader(cached_testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors())"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "528fe384-a365-4b53-bbfc-ed4dd261d222",
      "metadata": {
        "id": "528fe384-a365-4b53-bbfc-ed4dd261d222"
      },
      "source": [
        "A mini-batch now has the dimensions (time steps, batch size, channels, height, width). The number of time steps will be set to that of the longest recording in the mini-batch, and all other samples will be padded with zeros to match it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c9e37337-ad4a-43d5-b429-81a18de5148e",
      "metadata": {
        "id": "c9e37337-ad4a-43d5-b429-81a18de5148e"
      },
      "outputs": [],
      "source": [
        "event_tensor, target = next(iter(trainloader))\n",
        "print(event_tensor.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "61ae5d4b-2bb3-4191-9f96-04b3c6ba4c41",
      "metadata": {
        "id": "61ae5d4b-2bb3-4191-9f96-04b3c6ba4c41"
      },
      "source": [
        "## 2.1 Defining our network\n",
        "We will use snnTorch + PyTorch to construct a CSNN, just as in the previous tutorial. The convolutional network architecture to be used is: 12C5-MP2-32C5-MP2-800FC10\n",
        "\n",
        "- 12C5 is a 5$\\times$5 convolutional kernel with 12 filters\n",
        "- MP2 is a 2$\\times$2 max-pooling function\n",
        "- 800FC10 is a fully-connected layer that maps 800 neurons to 10 outputs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "HpKDIkRKUIAB",
      "metadata": {
        "id": "HpKDIkRKUIAB"
      },
      "outputs": [],
      "source": [
        "import snntorch as snn\n",
        "from snntorch import surrogate\n",
        "from snntorch import functional as SF\n",
        "from snntorch import spikeplot as splt\n",
        "from snntorch import utils\n",
        "import torch.nn as nn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "107cb645-0227-4290-9e1b-25d6ae7eac87",
      "metadata": {
        "id": "107cb645-0227-4290-9e1b-25d6ae7eac87"
      },
      "outputs": [],
      "source": [
        "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
        "\n",
        "# neuron and simulation parameters\n",
        "spike_grad = surrogate.fast_sigmoid(slope=75)\n",
        "beta = 0.5\n",
        "\n",
        "#  Initialize Network\n",
        "net = nn.Sequential(nn.Conv2d(2, 12, 5),\n",
        "                    nn.MaxPool2d(2),\n",
        "                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),\n",
        "                    nn.Conv2d(12, 32, 5),\n",
        "                    nn.MaxPool2d(2),\n",
        "                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),\n",
        "                    nn.Flatten(),\n",
        "                    nn.Linear(32*5*5, 10),\n",
        "                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)\n",
        "                    ).to(device)"
      ]
    },
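    {
      "cell_type": "markdown",
      "id": "sketch-feature-map-size",
      "metadata": {
        "id": "sketch-feature-map-size"
      },
      "source": [
        "The 800 inputs of the final fully-connected layer follow from the $34 \\times 34$ input size and the layer sequence. A short check, assuming the PyTorch defaults used in the network definition (stride 1 and no padding for the convolutions, stride equal to the kernel size for pooling); `conv_out` is a helper introduced here for illustration:\n",
        "\n",
        "```python\n",
        "def conv_out(size, kernel, stride=1, padding=0):\n",
        "    # Output width/height of a square convolution or pooling window\n",
        "    return (size + 2 * padding - kernel) // stride + 1\n",
        "\n",
        "size = 34                                   # N-MNIST frames are 34x34\n",
        "size = conv_out(size, kernel=5)             # 12C5 -> 30\n",
        "size = conv_out(size, kernel=2, stride=2)   # MP2  -> 15\n",
        "size = conv_out(size, kernel=5)             # 32C5 -> 11\n",
        "size = conv_out(size, kernel=2, stride=2)   # MP2  -> 5\n",
        "n_features = 32 * size * size\n",
        "print(size, n_features)  # 5 800\n",
        "```\n",
        "\n",
        "This is why the linear layer is declared with `32*5*5` input features."
      ]
    },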
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "zPFvlqOGi_uW",
      "metadata": {
        "id": "zPFvlqOGi_uW"
      },
      "outputs": [],
      "source": [
        "# this time, we won't return membrane as we don't need it \n",
        "\n",
        "def forward_pass(net, data):  \n",
        "  spk_rec = []\n",
        "  utils.reset(net)  # resets hidden states for all LIF neurons in net\n",
        "\n",
        "  for step in range(data.size(0)):  # data.size(0) = number of time steps\n",
        "      spk_out, mem_out = net(data[step])\n",
        "      spk_rec.append(spk_out)\n",
        "  \n",
        "  return torch.stack(spk_rec)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "23569dfc-e4a7-490f-8a68-c9ade5e03028",
      "metadata": {
        "id": "23569dfc-e4a7-490f-8a68-c9ade5e03028"
      },
      "source": [
        "## 2.2 Training\n",
        "\n",
        "In the previous tutorial, Cross Entropy Loss was applied to the total spike count to maximize the number of spikes from the correct class.\n",
        "\n",
        "Another option from the `snn.functional` module is to specify the target number of spikes from correct and incorrect classes. The approach below uses the *Mean Square Error Spike Count Loss*, which aims to elicit spikes from the correct class 80\\% of the time, and 20\\% of the time from incorrect classes. Encouraging incorrect neurons to fire could be motivated to avoid dead neurons."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "VocYbtD7Vwp7",
      "metadata": {
        "id": "VocYbtD7Vwp7"
      },
      "outputs": [],
      "source": [
        "optimizer = torch.optim.Adam(net.parameters(), lr=2e-2, betas=(0.9, 0.999))\n",
        "loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)"
      ]
    },
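    {
      "cell_type": "markdown",
      "id": "sketch-spike-count-targets",
      "metadata": {
        "id": "sketch-spike-count-targets"
      },
      "source": [
        "To illustrate what the 80\\% and 20\\% rates mean in terms of spike counts, here is a pure-Python sketch of the idea. It is not snnTorch's implementation: `count_targets` and `mse` are hypothetical helpers, and the 300 time steps are an assumed round number for one N-MNIST recording.\n",
        "\n",
        "```python\n",
        "def count_targets(label, num_classes, num_steps, correct_rate, incorrect_rate):\n",
        "    # Hypothetical helper: target spike count for each output neuron\n",
        "    on_target = round(num_steps * correct_rate)\n",
        "    off_target = round(num_steps * incorrect_rate)\n",
        "    return [on_target if c == label else off_target for c in range(num_classes)]\n",
        "\n",
        "def mse(counts, targets):\n",
        "    # Mean squared error between observed and target spike counts\n",
        "    return sum((c - t) ** 2 for c, t in zip(counts, targets)) / len(counts)\n",
        "\n",
        "targets = count_targets(label=3, num_classes=10, num_steps=300,\n",
        "                        correct_rate=0.8, incorrect_rate=0.2)\n",
        "print(targets[3], targets[0])  # 240 60\n",
        "```\n",
        "\n",
        "A network whose output spike counts match these targets exactly would incur zero loss under this sketch."
      ]
    },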
    {
      "cell_type": "markdown",
      "id": "7xkKLsqnmzcw",
      "metadata": {
        "id": "7xkKLsqnmzcw"
      },
      "source": [
        "Training neuromorphic data is expensive as it requires sequentially iterating through many time steps (approximately 300 time steps in the N-MNIST dataset). The following simulation will take some time, so we will just stick to training across 50 iterations (which is roughly 1/10th of a full epoch). Feel free to change `num_iters` if you have more time to kill. As we are printing results at each iteration, the results will be quite noisy and will also take some time before we start to see any sort of improvement.\n",
        "\n",
        "In our own experiments, it took about 20 iterations before we saw any improvement, and after 50 iterations, managed to crack ~60% accuracy. \n",
        "\n",
        "> Warning: the following simulation will take a while. Go make yourself a coffee, or ten. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "R4GbPSdTUcUR",
      "metadata": {
        "id": "R4GbPSdTUcUR"
      },
      "outputs": [],
      "source": [
        "num_epochs = 1\n",
        "num_iters = 50\n",
        "\n",
        "loss_hist = []\n",
        "acc_hist = []\n",
        "\n",
        "# training loop\n",
        "for epoch in range(num_epochs):\n",
        "    for i, (data, targets) in enumerate(iter(trainloader)):\n",
        "        data = data.to(device)\n",
        "        targets = targets.to(device)\n",
        "\n",
        "        net.train()\n",
        "        spk_rec = forward_pass(net, data)\n",
        "        loss_val = loss_fn(spk_rec, targets)\n",
        "\n",
        "        # Gradient calculation + weight update\n",
        "        optimizer.zero_grad()\n",
        "        loss_val.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        # Store loss history for future plotting\n",
        "        loss_hist.append(loss_val.item())\n",
        " \n",
        "        print(f\"Epoch {epoch}, Iteration {i} \\nTrain Loss: {loss_val.item():.2f}\")\n",
        "\n",
        "        acc = SF.accuracy_rate(spk_rec, targets) \n",
        "        acc_hist.append(acc)\n",
        "        print(f\"Accuracy: {acc * 100:.2f}%\\n\")\n",
        "\n",
        "        # This will end training after 50 iterations by default\n",
        "        if i == num_iters:\n",
        "          break"
      ]
    },
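    {
      "cell_type": "markdown",
      "id": "sketch-iters-per-epoch",
      "metadata": {
        "id": "sketch-iters-per-epoch"
      },
      "source": [
        "The estimate that 50 iterations cover roughly a tenth of an epoch follows from the dataset and batch sizes. A quick check, assuming the standard 60,000-sample N-MNIST training split (the variable names below are local to this sketch):\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "n_train = 60000       # assumed size of the N-MNIST training split\n",
        "sketch_batch_size = 128\n",
        "sketch_num_iters = 50\n",
        "\n",
        "iters_per_epoch = math.ceil(n_train / sketch_batch_size)\n",
        "fraction = round(sketch_num_iters / iters_per_epoch, 2)\n",
        "print(iters_per_epoch, fraction)  # 469 0.11\n",
        "```"
      ]
    },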
    {
      "cell_type": "markdown",
      "id": "YVjUzNcX0wld",
      "metadata": {
        "id": "YVjUzNcX0wld"
      },
      "source": [
        "# 3. Results\n",
        "## 3.1 Plot Test Accuracy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "yp2aTX2_1zFG",
      "metadata": {
        "id": "yp2aTX2_1zFG"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Plot Loss\n",
        "fig = plt.figure(facecolor=\"w\")\n",
        "plt.plot(acc_hist)\n",
        "plt.title(\"Train Set Accuracy\")\n",
        "plt.xlabel(\"Iteration\")\n",
        "plt.ylabel(\"Accuracy\")\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "1gb1wCQb2bMd",
      "metadata": {
        "id": "1gb1wCQb2bMd"
      },
      "source": [
        "## 3.2 Spike Counter\n",
        "\n",
        "Run a forward pass on a batch of data to obtain spike recordings."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "qLAfvj9D2AYd",
      "metadata": {
        "id": "qLAfvj9D2AYd"
      },
      "outputs": [],
      "source": [
        "spk_rec = forward_pass(net, data)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "VQnj40YC2hUV",
      "metadata": {
        "id": "VQnj40YC2hUV"
      },
      "source": [
        "Changing `idx` allows you to index into various samples from the simulated minibatch. Use `splt.spike_count` to explore the spiking behaviour of a few different samples. Generating the following animation will take some time.\n",
        "\n",
 Note: if you are running the notebook locally">
        "> Note: if you are running the notebook locally on your desktop, please uncomment the `plt.rcParams` line below and modify the path to point to your ffmpeg.exe"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "oTKhuyk22M57",
      "metadata": {
        "id": "oTKhuyk22M57"
      },
      "outputs": [],
      "source": [
        "from IPython.display import HTML\n",
        "\n",
        "idx = 0\n",
        "\n",
        "fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))\n",
        "labels=['0', '1', '2', '3', '4', '5', '6', '7', '8','9']\n",
        "print(f\"The target label is: {targets[idx]}\")\n",
        "\n",
        "# plt.rcParams['animation.ffmpeg_path'] = 'C:\\\\path\\\\to\\\\your\\\\ffmpeg.exe'\n",
        "\n",
        "#  Plot spike count histogram\n",
        "anim = splt.spike_count(spk_rec[:, idx].detach().cpu(), fig, ax, labels=labels, \n",
        "                        animate=True, interpolate=1)\n",
        "\n",
        "HTML(anim.to_html5_video())\n",
        "# anim.save(\"spike_bar.mp4\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "-iSGTq0Q3Lcm",
      "metadata": {
        "id": "-iSGTq0Q3Lcm"
      },
      "source": [
        "# Conclusion\n",
        "If you made it this far, then congratulations - you have the patience of a monk. You should now also understand how to load neuromorphic datasets using Tonic and then train a network using snnTorch. [In the next tutorial](https://snntorch.readthedocs.io/en/latest/tutorials/index.html), we will learn more advanced techniques, such as introducing long-term temporal dynamics into our SNNs.\n",
        "\n",
        "If you like this project, please consider starring ⭐ the repo on GitHub as it is the easiest and best way to support it.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "h-K_DUnsMKnv",
      "metadata": {
        "id": "h-K_DUnsMKnv"
      },
      "source": [
        "# Additional Resources\n",
        "* [Check out the snnTorch GitHub project here.](https://github.com/jeshraghian/snntorch)\n",
        "* [The Tonic GitHub project can be found here.](https://github.com/neuromorphs/tonic)\n",
        "* The N-MNIST Dataset was originally published in the following paper: [Orchard, G.; Cohen, G.; Jayawant, A.; and Thakor, N.  “Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades\", Frontiers in Neuroscience, vol.9, no.437, Oct. 2015.](https://www.frontiersin.org/articles/10.3389/fnins.2015.00437/full) \n",
        "* For further information about how N-MNIST was created, please refer to [Garrick Orchard's website here.](https://www.garrickorchard.com/datasets/n-mnist)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "include_colab_link": true,
      "name": "Copy of tutorial_5_neuromorphic_datasets.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.11"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
